#include "AiUnitTests.h"

#include <cassert>

#include "ai/steering/SteeringArrive.h"
#include "ai/steering/SteeringEvade.h"
#include "ai/steering/SteeringFlee.h"
#include "ai/steering/SteeringOutput.h"
#include "ai/steering/SteeringPursue.h"
#include "ai/steering/SteeringSeek.h"
#include "ai/steering/SteeringVelocityMatch.h"
#include "ai/steering/SteeringWander.h"
#include "math/Point3.h"
#include "math/Vector3.h"
#include "math/Utils.h"

namespace {
    // Checks that SteeringOutput stores the linear velocity passed to its constructor.
    void steeringOutputUnitTests() {
        // SteeringOutput::Constructor
        {
            // NOTE(review): only the linear component is covered here; add an
            // angular-velocity check if SteeringOutput exposes one.
            const Vector3 linearVel(1.0f, 0.0f, 1.0f);
            const SteeringOutput steeringOutput(linearVel);

            assert(linearVel == steeringOutput.mLinearVel);
        }
    }

    void steeringSeekUnitTests() {
        // SeekData::Constructor:
        {
            const Point3 src(1.0f, 0.0f, 1.0f);
            const Point3 dest(0.0f, 0.0f, 10.0f);
            const float speed = 5.0f;
            const SeekData seekData(&src, &dest, speed);

            assert(src == *seekData.mSrcPos);
            assert(dest == *seekData.mTargetPos);
            assert(areEquals(speed, seekData.mSpeed) == true);
        }

        // set():
        {
            const Point3 src(1.0f, 0.0f, 1.0f);
            const Point3 dest(0.0f, 0.0f, 10.0f);
            const float speed = 5.0f;
            SeekData seekData;
            seekData.set(src, dest, speed);

            assert(src == *seekData.mSrcPos);
            assert(dest == *seekData.mTargetPos);
            assert(areEquals(speed, seekData.mSpeed) == true);
        }

        // seek():
        {
            SeekData seekData;
            SteeringOutput output;

            const Point3 src1(0.0f, 0.0f, 0.0f);
            const Point3 dest1(0.0f, 0.0f, 0.0f);
            const float speed1 = 100.0f;
            seekData.set(src1, dest1, speed1);
            seek(&seekData, &output, 1);
            const Vector3 linearVel1(0.0f, 0.0f, 0.0f);
            assert(linearVel1 == output.mLinearVel);

            const Point3 src2(0.0f, 0.0f, 0.0f);
            const Point3 dest2(0.0f, 0.0f, 1.0f);
            const float speed2 = 100.0f;
            seekData.set(src2, dest2, speed2);
            seek(&seekData, &output, 1);
            const Vector3 linearVel2(0.0f, 0.0f, 100.0f);
            assert(linearVel2 == output.mLinearVel);

            const Point3 src3(10.0f, 0.0f, 0.0f);
            const Point3 dest3(0.0f, 0.0f, 0.0f);
            const float speed3 = 1.0f;
            seekData.set(src3, dest3, speed3);
            seek(&seekData, &output, 1);
            const Vector3 linearVel3(-1.0f, 0.0f, 0.0f);
            assert(linearVel3 == output.mLinearVel);
        }
    }

    void steeringFleeUnitTests() {
        // FleeData::Constructor:
        {
            const Point3 src(1.0f, 0.0f, 1.0f);
            const Point3 dest(0.0f, 0.0f, 10.0f);
            const float speed = 5.0f;
            const FleeData fleeData(&src, &dest, speed);

            assert(src == *fleeData.mSrcPos);
            assert(dest == *fleeData.mTargetPos);

            assert(areEquals(speed, fleeData.mSpeed) == true);
        }

        // set():
        {
            const Point3 src(1.0f, 0.0f, 1.0f);
            const Point3 dest(0.0f, 0.0f, 10.0f);
            const float speed = 5.0f;
            FleeData fleeData;
            fleeData.set(src, dest, speed);

            assert(src == *fleeData.mSrcPos);
            assert(dest == *fleeData.mTargetPos);
            assert(areEquals(speed, fleeData.mSpeed) == true);
        }

        // flee():
        {
            FleeData fleeData;
            SteeringOutput output;

            const Point3 src1(0.0f, 0.0f, 0.0f);
            const Point3 dest1(0.0f, 0.0f, 0.0f);
            const float speed1 = 100.0f;
            fleeData.set(src1, dest1, speed1);
            flee(&fleeData, &output, 1);
            const Vector3 linearVel1(0.0f, 0.0f, 0.0f);
            assert(linearVel1 == output.mLinearVel);

            const Point3 src2(0.0f, 0.0f, 0.0f);
            const Point3 dest2(0.0f, 0.0f, 1.0f);
            const float speed2 = 100.0f;
            fleeData.set(src2, dest2, speed2);
            flee(&fleeData, &output, 1);
            const Vector3 linearVel2(0.0f, 0.0f, -100.0f);
            assert(linearVel2 == output.mLinearVel);

            const Point3 src3(10.0f, 0.0f, 0.0f);
            const Point3 dest3(0.0f, 0.0f, 0.0f);
            const float speed3 = 1.0f;
            fleeData.set(src3, dest3, speed3);
            flee(&fleeData, &output, 1);
            const Vector3 linearVel3(1.0f, 0.0f, 0.0f);
            assert(linearVel3 == output.mLinearVel);
        }
    }

    void steeringArriveUnitTests() {
        // ArriveData::Constructor:
        {
            const Point3 srcPos(1.0f, 0.0f, 1.0f);
            const Vector3 srcLinearVel(10.0f, 0.0f, 0.0f);
            const Point3 targetPos(0.0f, 0.0f, 10.0f);
            const float targetRadius = 10.0f;
            const float slowDownRadius = 8.0f;
            const float timeToTarget = 0.25f;
            const ArriveData data(&srcPos,
                                      &srcLinearVel,
                                      &targetPos, 
                                      targetRadius, 
                                      slowDownRadius,
                                      timeToTarget);

            assert(srcPos == *data.mSrcPos);
            assert(srcLinearVel == *data.mSrcLinearVel);
            assert(targetPos == *data.mTargetPos);
            assert(areEquals(targetRadius, data.mTargetRadius) == true);
            assert(areEquals(slowDownRadius, data.mSlowDownRadius) == true);
            assert(areEquals(timeToTarget, data.mTimeToTarget) == true);
        }

        // set():
        {
            const Point3 srcPos(1.0f, 0.0f, 1.0f);
            const Vector3 srcLinearVel(10.0f, 0.0f, 0.0f);
            const Point3 targetPos(0.0f, 0.0f, 10.0f);
            const float targetRadius = 10.0f;
            const float slowDownRadius = 8.0f;
            const float timeToTarget = 0.25f;
            ArriveData data;

            data.set(srcPos,
                     srcLinearVel,
                     targetPos,
                     targetRadius, 
                     slowDownRadius,
                     timeToTarget);

            assert(srcPos == *data.mSrcPos);
            assert(srcLinearVel == *data.mSrcLinearVel);
            assert(targetPos == *data.mTargetPos);
            assert(areEquals(targetRadius, data.mTargetRadius) == true);
            assert(areEquals(slowDownRadius, data.mSlowDownRadius) == true);
            assert(areEquals(timeToTarget, data.mTimeToTarget) == true);
        }
    }

    void steeringVelocityMatch() {
        // VelocityMatchData::Constructor:
        {
            const Vector3 srcLinearVel(10.0f, 0.0f, 0.0f);
            const Vector3 targetLinearVel(0.0f, 0.0f, 10.0f);
            const float timeToTarget = 0.25f;
            const VelocityMatchData data(&srcLinearVel,
                                             &targetLinearVel,
                                             timeToTarget);

            assert(srcLinearVel == *data.mSrcLinearVel);
            assert(targetLinearVel == *data.mTargetLinearVel);
            assert(areEquals(timeToTarget, data.mTimeToTarget) == true);
        }

        // set():
        {
            const Vector3 srcLinearVel(10.0f, 0.0f, 0.0f);
            const Vector3 targetLinearVel(0.0f, 0.0f, 10.0f);
            const float timeToTarget = 0.25f;
            VelocityMatchData data;
            data.set(srcLinearVel,
                     targetLinearVel,
                     timeToTarget);

            assert(srcLinearVel == *data.mSrcLinearVel);
            assert(targetLinearVel == *data.mTargetLinearVel);
            assert(areEquals(timeToTarget, data.mTimeToTarget) == true);
        }
    }

    void steeringPursueTests() {
        // PursueData::Constructor:
        {
            const Point3 srcPos(1.0f, 0.0f, 1.0f);
            const Vector3 srcLinearVel(10.0f, 0.0f, 0.0f);
            const Point3 targetPos(0.0f, 0.0f, 10.0f);
            const Vector3 targetLinearVel(10.0f, 0.0f, 0.0f);
            const float maxPredictionTime = 1.0f;
            const PursueData data(&srcPos,
                                      &srcLinearVel,
                                      &targetPos, 
                                      &targetLinearVel,  
                                      maxPredictionTime);

            assert(srcPos == *data.mSrcPos);
            assert(srcLinearVel == *data.mSrcLinearVel);
            assert(targetPos == *data.mTargetPos);
            assert(targetLinearVel == *data.mTargetLinearVel);
            assert(areEquals(maxPredictionTime, data.mMaxPredictionTime) == true);
        }

        // set():
        {
            const Point3 srcPos(1.0f, 0.0f, 1.0f);
            const Vector3 srcLinearVel(10.0f, 0.0f, 0.0f);
            const Point3 targetPos(0.0f, 0.0f, 10.0f);
            const Vector3 targetLinearVel(10.0f, 0.0f, 0.0f);
            const float maxPredictionTime = 1.0f;
            PursueData data;
            data.set(srcPos,
                     srcLinearVel,
                     targetPos, 
                     targetLinearVel, 
                     maxPredictionTime);

            assert(srcPos == *data.mSrcPos);
            assert(srcLinearVel == *data.mSrcLinearVel);
            assert(targetPos == *data.mTargetPos);
            assert(targetLinearVel == *data.mTargetLinearVel);
            assert(areEquals(maxPredictionTime, data.mMaxPredictionTime) == true);
        }
    }

    void steeringEvadeTests() {
        // EvadeData::Constructor:
        {
            const Point3 srcPos(1.0f, 0.0f, 1.0f);
            const Vector3 srcLinearVel(10.0f, 0.0f, 0.0f);
            const Point3 targetPos(0.0f, 0.0f, 10.0f);
            const Vector3 targetLinearVel(10.0f, 0.0f, 0.0f);
            const float maxPredictionTime = 1.0f;
            const EvadeData data(&srcPos,
                                     &srcLinearVel,
                                     &targetPos, 
                                     &targetLinearVel, 
                                     maxPredictionTime);

            assert(srcPos == *data.mSrcPos);
            assert(srcLinearVel == *data.mSrcLinearVel);
            assert(targetPos == *data.mTargetPos);
            assert(targetLinearVel == *data.mTargetLinearVel);
            assert(areEquals(maxPredictionTime, data.mMaxPredictionTime) == true);
        }

        // set():
        {
            const Point3 srcPos(1.0f, 0.0f, 1.0f);
            const Vector3 srcLinearVel(10.0f, 0.0f, 0.0f);
            const Point3 targetPos(0.0f, 0.0f, 10.0f);
            const Vector3 targetLinearVel(10.0f, 0.0f, 0.0f);
            const float maxPredictionTime = 1.0f;
            EvadeData data;
            data.set(srcPos,
                     srcLinearVel,
                     targetPos, 
                     targetLinearVel, 
                     maxPredictionTime);

            assert(srcPos == *data.mSrcPos);
            assert(srcLinearVel == *data.mSrcLinearVel);
            assert(targetPos == *data.mTargetPos);
            assert(targetLinearVel == *data.mTargetLinearVel);
            assert(areEquals(maxPredictionTime, data.mMaxPredictionTime) == true);
        }
    }

    void steeringWanderTests() {
        // WanderData::Constructor
        {
            const Point3 srcPos(1.0f, 2.0f, 3.0f);
            const Vector3 srcLinearVel(0.0f, 1.0f, 2.0f);
            const float srcOrientation = 10.0f;
            const float wanderOrientation = 0.0f;
            const WanderData data(&srcPos, &srcLinearVel, &srcOrientation);

            assert(srcPos == *data.mSrcPos);
            assert(srcLinearVel == *data.mSrcLinearVel);
            assert(areEquals(srcOrientation, *data.mSrcOrientation) == true);
            assert(areEquals(wanderOrientation, data.mWanderOrientation) == true);
        }

        // set
        {
            const Point3 srcPos(1.0f, 2.0f, 3.0f);
            const Vector3 srcLinearVel(0.0f, 1.0f, 2.0f);
            const float srcOrientation = 10.0f;
            const float wanderOrientation = 0.0f;
            WanderData data;
            data.set(srcPos, srcLinearVel, srcOrientation);

            assert(srcPos == *data.mSrcPos);
            assert(srcLinearVel == *data.mSrcLinearVel);
            assert(areEquals(srcOrientation, *data.mSrcOrientation) == true);
            assert(areEquals(wanderOrientation, data.mWanderOrientation) == true);
        }
    }
}

namespace aiUnitTests {
    // Runs the steering test suites listed below, one after another, in order.
    // Note: every check in those suites uses assert(), so the checks are compiled
    // out when NDEBUG is defined.
    void runTests() {
        typedef void (*TestSuite)();
        const TestSuite suites[] = {
            steeringOutputUnitTests,
            steeringSeekUnitTests,
            steeringFleeUnitTests,
            steeringArriveUnitTests,
            steeringVelocityMatch,
            steeringPursueTests,
            steeringEvadeTests,
            steeringWanderTests
        };

        const unsigned int suiteCount = sizeof(suites) / sizeof(suites[0]);
        for (unsigned int i = 0; i < suiteCount; ++i) {
            suites[i]();
        }
    }
}